# ~*~ coding: utf-8 ~*~
import paramiko
from io import StringIO
from Log import logger_task

class SSHConnection(object):
    """Run commands and transfer files on a remote host over SSH (paramiko).

    ``assets`` must provide the keys ``ip``, ``port``, ``username`` and
    ``password``.  ``private_key`` is optional; when present and non-empty it
    is parsed as an RSA private key given as text.  If the key cannot be
    parsed, or the server rejects it, password authentication is used.
    """

    def __init__(self, assets):
        self.host = assets['ip']
        self.port = int(assets['port'])
        self.username = assets['username']
        self.pwd = assets['password']
        self.private_key = None
        # Most recently opened paramiko.Transport (None until connect()).
        self.transport = None
        # 1 -> authenticate with self.private_key, 0 -> authenticate with self.pwd.
        self.flag = 0

        # The key is optional, so do not require the entry to exist.
        private_key = assets.get('private_key')
        if private_key:
            try:
                self.private_key = paramiko.RSAKey.from_private_key(StringIO(private_key))
                self.flag = 1
            except Exception as e:
                # Deliberate best effort: an unusable key means password login.
                logger_task.info(e)
                self.flag = 0

    def connect(self):
        """Open an authenticated transport and store it on ``self.transport``.

        Any transport left over from a previous call is closed first.  When a
        private key was loaded (``flag == 1``), key authentication is tried
        first; if the server rejects it, the login is retried with the
        password.

        :return: the connected ``paramiko.Transport``.
        :raises paramiko.SSHException: on connection or authentication failure.
        """
        self.close()
        if self.flag == 1:
            try:
                self.transport = self._open_transport(pkey=self.private_key)
                self.transport_private_key = self.transport
                return self.transport
            except paramiko.AuthenticationException as e:
                # Key rejected: fall through to password authentication.
                logger_task.info(e)
        self.transport = self._open_transport(password=self.pwd)
        self.transport_password = self.transport
        return self.transport

    def close(self):
        """Close the current transport, if one was ever opened."""
        transport = getattr(self, 'transport', None)
        if transport is not None:
            transport.close()

    def _open_transport(self, **auth):
        """Connect to host:port and authenticate with ``auth``; never leak on failure."""
        transport = paramiko.Transport((self.host, self.port))
        try:
            transport.connect(username=self.username, **auth)
        except Exception:
            # The transport already owns a socket and a thread; release them.
            transport.close()
            raise
        return transport

    def _acquire_transport(self):
        """Return ``(transport, opened_here)``, reusing an active transport if any."""
        transport = getattr(self, 'transport', None)
        if transport is not None and transport.is_active():
            return transport, False
        return self.connect(), True

    def run_cmd(self, command):
        """Execute a shell command on the remote host over a fresh connection.

        The connection is always closed before returning, even on error.

        :param command: command line to execute remotely.
        :return: ``{'stdout': str, 'stderr': str}`` holding the decoded output
            streams.
        """
        self.connect()
        try:
            channel = self.transport.open_session()
            try:
                channel.exec_command(command)
                raw_stdout = channel.makefile('rb', -1).read()
                raw_stderr = channel.makefile_stderr('rb', -1).read()
            finally:
                channel.close()
        finally:
            self.close()
        # Remote output is not guaranteed to be valid UTF-8; don't crash on it.
        return {
            'stdout': raw_stdout.decode(errors='replace'),
            'stderr': raw_stderr.decode(errors='replace'),
        }

    def upload(self, local_path, target_path):
        """Upload ``local_path`` to ``target_path`` and make it executable.

        An already-active transport is reused; otherwise one is opened for
        this call and closed afterwards.

        :param local_path: path of the local file to send.
        :param target_path: destination path on the remote host.
        """
        transport, opened_here = self._acquire_transport()
        try:
            with paramiko.SFTPClient.from_transport(transport) as sftp:
                sftp.put(local_path, target_path, confirm=True)
                # File mode is octal, hence the 0o prefix (rwxr-xr-x).
                sftp.chmod(target_path, 0o755)
        finally:
            if opened_here:
                self.close()

    def download(self, target_path, local_path):
        """Download remote ``target_path`` to ``local_path``.

        An already-active transport is reused; otherwise one is opened for
        this call and closed afterwards.

        :param target_path: path of the file on the remote host.
        :param local_path: local destination path.
        """
        transport, opened_here = self._acquire_transport()
        try:
            with paramiko.SFTPClient.from_transport(transport) as sftp:
                sftp.get(target_path, local_path)
        finally:
            if opened_here:
                self.close()
